{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Execute this cell to install dependencies.\n",
    "# The extras spec is quoted so shells that glob square brackets (e.g. zsh) pass it to pip verbatim.\n",
    "%pip install \"sf-hamilton[visualization]\" pandas scikit-learn numpy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
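    "# MPG Simple Target [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/apache/hamilton/blob/main/examples/hamilton-tutorials/mpg-translation/MPGSimpleTarget.ipynb) [![GitHub badge](https://img.shields.io/badge/github-view_source-2b3137?logo=github)](https://github.com/apache/hamilton/blob/main/examples/hamilton-tutorials/mpg-translation/MPGSimpleTarget.ipynb)\n",
    "\n",
    "Fits a linear regression that predicts `MPG` on the UCI Auto MPG dataset, written as a Hamilton dataflow. The pipeline is defined in a `%%cell_to_module` cell, executed with a Hamilton driver (reporting the test-set mean absolute error), and then re-executed with `overrides` to reuse the already-fitted `linear_model` instead of refitting it.\n"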
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-07-20T17:32:57.945110Z",
     "start_time": "2024-07-20T17:32:50.323887Z"
    },
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "f7ca0a2e-99c4-49de-af45-c8c4bddf5685",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/stefankrawczyk/.pyenv/versions/knowledge_retrieval-py39/lib/python3.9/site-packages/pyspark/pandas/__init__.py:50: UserWarning: 'PYARROW_IGNORE_TIMEZONE' environment variable was not set. It is required to set this environment variable to '1' in both driver and executor sides if you use pyarrow>=2.0.0. pandas-on-Spark will set it for you but it does not work if there is a Spark context already launched.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "from hamilton import driver\n",
    "from IPython.display import HTML, display"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-07-20T17:32:57.962238Z",
     "start_time": "2024-07-20T17:32:57.947142Z"
    },
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "45fcd1cf-5dee-4d3c-b598-823c82654805",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "%load_ext hamilton.plugins.jupyter_magic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-07-20T17:37:00.825770Z",
     "start_time": "2024-07-20T17:37:00.183488Z"
    },
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "155ea802-aef6-4d5c-b264-d9ec5b57c733",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 10.0.1 (20240210.2158)\n",
       " -->\n",
       "<!-- Pages: 1 -->\n",
       "<svg width=\"512pt\" height=\"195pt\"\n",
       " viewBox=\"0.00 0.00 511.65 194.80\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 190.8)\">\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-190.8 507.65,-190.8 507.65,4 -4,4\"/>\n",
       "<g id=\"clust1\" class=\"cluster\">\n",
       "<title>cluster__legend</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"black\" points=\"8,-103.8 8,-178.8 92.85,-178.8 92.85,-103.8 8,-103.8\"/>\n",
       "<text text-anchor=\"middle\" x=\"50.43\" y=\"-161.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">Legend</text>\n",
       "</g>\n",
       "<!-- mpg_df -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>mpg_df</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M83.35,-93.6C83.35,-93.6 17.5,-93.6 17.5,-93.6 11.5,-93.6 5.5,-87.6 5.5,-81.6 5.5,-81.6 5.5,-42 5.5,-42 5.5,-36 11.5,-30 17.5,-30 17.5,-30 83.35,-30 83.35,-30 89.35,-30 95.35,-36 95.35,-42 95.35,-42 95.35,-81.6 95.35,-81.6 95.35,-87.6 89.35,-93.6 83.35,-93.6\"/>\n",
       "<text text-anchor=\"start\" x=\"25.68\" y=\"-70.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">mpg_df</text>\n",
       "<text text-anchor=\"start\" x=\"16.3\" y=\"-42.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">DataFrame</text>\n",
       "</g>\n",
       "<!-- data_sets -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>data_sets</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M196.2,-93.6C196.2,-93.6 136.35,-93.6 136.35,-93.6 130.35,-93.6 124.35,-87.6 124.35,-81.6 124.35,-81.6 124.35,-42 124.35,-42 124.35,-36 130.35,-30 136.35,-30 136.35,-30 196.2,-30 196.2,-30 202.2,-30 208.2,-36 208.2,-42 208.2,-42 208.2,-81.6 208.2,-81.6 208.2,-87.6 202.2,-93.6 196.2,-93.6\"/>\n",
       "<text text-anchor=\"start\" x=\"135.15\" y=\"-70.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">data_sets</text>\n",
       "<text text-anchor=\"start\" x=\"155.78\" y=\"-42.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n",
       "</g>\n",
       "<!-- mpg_df&#45;&gt;data_sets -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>mpg_df&#45;&gt;data_sets</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M95.77,-61.8C101.33,-61.8 107.06,-61.8 112.73,-61.8\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"112.7,-65.3 122.7,-61.8 112.7,-58.3 112.7,-65.3\"/>\n",
       "</g>\n",
       "<!-- evaluated_model -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>evaluated_model</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M491.65,-93.6C491.65,-93.6 383.8,-93.6 383.8,-93.6 377.8,-93.6 371.8,-87.6 371.8,-81.6 371.8,-81.6 371.8,-42 371.8,-42 371.8,-36 377.8,-30 383.8,-30 383.8,-30 491.65,-30 491.65,-30 497.65,-30 503.65,-36 503.65,-42 503.65,-42 503.65,-81.6 503.65,-81.6 503.65,-87.6 497.65,-93.6 491.65,-93.6\"/>\n",
       "<text text-anchor=\"start\" x=\"382.6\" y=\"-70.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">evaluated_model</text>\n",
       "<text text-anchor=\"start\" x=\"427.23\" y=\"-42.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n",
       "</g>\n",
       "<!-- data_sets&#45;&gt;evaluated_model -->\n",
       "<g id=\"edge3\" class=\"edge\">\n",
       "<title>data_sets&#45;&gt;evaluated_model</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M208.56,-69.39C217.94,-70.81 227.89,-72.07 237.2,-72.8 283.99,-76.46 295.95,-75.53 342.8,-72.8 348.44,-72.47 354.27,-72.03 360.12,-71.52\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"360.25,-75.02 369.87,-70.59 359.58,-68.05 360.25,-75.02\"/>\n",
       "</g>\n",
       "<!-- linear_model -->\n",
       "<g id=\"node4\" class=\"node\">\n",
       "<title>linear_model</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M330.8,-63.6C330.8,-63.6 249.2,-63.6 249.2,-63.6 243.2,-63.6 237.2,-57.6 237.2,-51.6 237.2,-51.6 237.2,-12 237.2,-12 237.2,-6 243.2,0 249.2,0 249.2,0 330.8,0 330.8,0 336.8,0 342.8,-6 342.8,-12 342.8,-12 342.8,-51.6 342.8,-51.6 342.8,-57.6 336.8,-63.6 330.8,-63.6\"/>\n",
       "<text text-anchor=\"start\" x=\"248\" y=\"-40.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">linear_model</text>\n",
       "<text text-anchor=\"start\" x=\"279.5\" y=\"-12.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n",
       "</g>\n",
       "<!-- data_sets&#45;&gt;linear_model -->\n",
       "<g id=\"edge4\" class=\"edge\">\n",
       "<title>data_sets&#45;&gt;linear_model</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M208.49,-51.64C214.03,-50.28 219.83,-48.85 225.65,-47.41\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"226.48,-50.81 235.36,-45.02 224.81,-44.02 226.48,-50.81\"/>\n",
       "</g>\n",
       "<!-- linear_model&#45;&gt;evaluated_model -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>linear_model&#45;&gt;evaluated_model</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M343.2,-42.55C348.77,-43.69 354.52,-44.88 360.3,-46.07\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"359.43,-49.46 369.93,-48.05 360.84,-42.61 359.43,-49.46\"/>\n",
       "</g>\n",
       "<!-- function -->\n",
       "<g id=\"node5\" class=\"node\">\n",
       "<title>function</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M72.85,-148.1C72.85,-148.1 28,-148.1 28,-148.1 22,-148.1 16,-142.1 16,-136.1 16,-136.1 16,-123.5 16,-123.5 16,-117.5 22,-111.5 28,-111.5 28,-111.5 72.85,-111.5 72.85,-111.5 78.85,-111.5 84.85,-117.5 84.85,-123.5 84.85,-123.5 84.85,-136.1 84.85,-136.1 84.85,-142.1 78.85,-148.1 72.85,-148.1\"/>\n",
       "<text text-anchor=\"middle\" x=\"50.43\" y=\"-124\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">function</text>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.graphs.Digraph at 0x14b938ca0>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "%%cell_to_module pipeline --display\n",
    "# when done you can write to file and then load it as a module normally\n",
    "# add -w to do so\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.linear_model import LinearRegression\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.metrics import mean_absolute_error\n",
    "\n",
    "\n",
    "def mpg_df() -> pd.DataFrame:\n",
    "    \"\"\"Downloads the UCI Auto MPG dataset and returns it as a DataFrame.\n",
    "\n",
    "    '?' entries are parsed as NaN, any text after a tab character on a line is ignored,\n",
    "    and the 'Model Year' column is renamed to 'ModelYear' so no column name contains a space.\n",
    "    \"\"\"\n",
    "    # HTTPS: avoid fetching the training data over an unauthenticated plain-HTTP connection.\n",
    "    url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/auto-mpg/auto-mpg.data'\n",
    "    column_names = ['MPG', 'Cylinders', 'Displacement', 'Horsepower', 'Weight',\n",
    "                    'Acceleration', 'Model Year', 'Origin']\n",
    "\n",
    "    raw_dataset = pd.read_csv(url, names=column_names,\n",
    "                              na_values='?', comment='\\t',\n",
    "                              sep=' ', skipinitialspace=True)\n",
    "\n",
    "    # schema manipulation: column names without spaces are easier to reference downstream\n",
    "    return raw_dataset.rename(columns={\"Model Year\": \"ModelYear\"})\n",
    "\n",
    "def data_sets(mpg_df: pd.DataFrame) -> dict:\n",
    "    \"\"\"Builds reproducible train/test splits from the raw MPG data.\n",
    "\n",
    "    One-hot indicator columns USA/Europe/Japan are derived from `Origin`, rows with any\n",
    "    NaN are dropped, and an 80/20 random split is taken with a fixed seed.\n",
    "\n",
    "    :param mpg_df: raw MPG DataFrame; it is not modified.\n",
    "    :return: {'train': DataFrame, 'test': DataFrame} with disjoint row indices.\n",
    "    \"\"\"\n",
    "    # Work on a copy: adding columns to the upstream node's value in place would leak them\n",
    "    # to anyone who also requests `mpg_df`, or who passes the same frame in as an override.\n",
    "    encoded_df = mpg_df.copy()\n",
    "    # one hot encode -- we know the encoding here.\n",
    "    for value, country in {1: \"USA\", 2: \"Europe\", 3: \"Japan\"}.items():\n",
    "        encoded_df[country] = np.where(encoded_df[\"Origin\"] == value, 1, 0)\n",
    "    clean_df = encoded_df.dropna()\n",
    "    # create data sets\n",
    "    train_test_split = 0.8\n",
    "    seed = 123  # fixed so the split, and therefore the reported MAE, is reproducible\n",
    "    # split the pandas dataframe into train and test\n",
    "    train_dataset = clean_df.sample(frac=train_test_split, random_state=seed)\n",
    "    test_dataset = clean_df.drop(train_dataset.index)\n",
    "\n",
    "    return {\"train\": train_dataset, \"test\": test_dataset}\n",
    "\n",
    "\n",
    "def linear_model(data_sets: dict) -> dict:\n",
    "    \"\"\"Fits a linear regression on standardized training features to predict MPG.\n",
    "\n",
    "    :param data_sets: output of `data_sets`; only its 'train' frame is read, and it is not modified.\n",
    "    :return: {'linear_model': fitted LinearRegression, 'scaler': StandardScaler fitted on the\n",
    "        training features}. The scaler must be reused to transform any data passed to the model.\n",
    "    \"\"\"\n",
    "    train_dataset = data_sets[\"train\"]\n",
    "    ## config for fitting a model\n",
    "    target_column: str = \"MPG\"\n",
    "\n",
    "    # Split target from features without `pop`: pop would delete the column from the shared\n",
    "    # upstream frame, so re-running on the same data_sets (e.g. via overrides) hits KeyError.\n",
    "    train_labels = train_dataset[target_column]\n",
    "    train_features = train_dataset.drop(columns=[target_column])\n",
    "    # Convert boolean columns to integers for the model\n",
    "    bool_columns = train_features.select_dtypes(include=[bool]).columns\n",
    "    train_features[bool_columns] = train_features[bool_columns].astype(int)\n",
    "    # Normalize the features for the model\n",
    "    scaler = StandardScaler()\n",
    "    train_features_scaled = scaler.fit_transform(train_features)\n",
    "\n",
    "    # Initialize and fit the Linear Regression model\n",
    "    model = LinearRegression()\n",
    "    model.fit(train_features_scaled, train_labels)\n",
    "    return {\"linear_model\": model, \"scaler\": scaler}\n",
    "\n",
    "def evaluated_model(linear_model: dict, data_sets: dict) -> dict:\n",
    "    \"\"\"Scores the fitted model on the held-out test set.\n",
    "\n",
    "    :param linear_model: output of `linear_model`: {'linear_model': estimator, 'scaler': fitted scaler}.\n",
    "    :param data_sets: output of `data_sets`; only its 'test' frame is read, and it is not modified.\n",
    "    :return: {'linear_model': mean absolute error of the model's test-set predictions}.\n",
    "    \"\"\"\n",
    "    test_dataset = data_sets[\"test\"]\n",
    "    target_column: str = \"MPG\"\n",
    "    # Split target from features without `pop`, which would delete the column from the\n",
    "    # shared upstream frame and break a re-run on the same data_sets.\n",
    "    test_labels = test_dataset[target_column]\n",
    "    test_features = test_dataset.drop(columns=[target_column])\n",
    "    # Same preprocessing as at fit time: bools -> ints, then the scaler fitted on the train set\n",
    "    bool_columns = test_features.select_dtypes(include=[bool]).columns\n",
    "    test_features[bool_columns] = test_features[bool_columns].astype(int)\n",
    "    test_features_scaled = linear_model[\"scaler\"].transform(test_features)\n",
    "\n",
    "    # Predict and evaluate the model\n",
    "    test_pred = linear_model[\"linear_model\"].predict(test_features_scaled)\n",
    "    mae = mean_absolute_error(test_labels, test_pred)\n",
    "    return {\"linear_model\": mae}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-07-20T17:37:04.415885Z",
     "start_time": "2024-07-20T17:37:04.097149Z"
    },
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "17b10355-4e25-4e75-84c5-fd95c0bd3dfb",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 10.0.1 (20240210.2158)\n",
       " -->\n",
       "<!-- Pages: 1 -->\n",
       "<svg width=\"512pt\" height=\"195pt\"\n",
       " viewBox=\"0.00 0.00 511.65 194.80\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 190.8)\">\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-190.8 507.65,-190.8 507.65,4 -4,4\"/>\n",
       "<g id=\"clust1\" class=\"cluster\">\n",
       "<title>cluster__legend</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"black\" points=\"8,-103.8 8,-178.8 92.85,-178.8 92.85,-103.8 8,-103.8\"/>\n",
       "<text text-anchor=\"middle\" x=\"50.43\" y=\"-161.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">Legend</text>\n",
       "</g>\n",
       "<!-- mpg_df -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>mpg_df</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M83.35,-93.6C83.35,-93.6 17.5,-93.6 17.5,-93.6 11.5,-93.6 5.5,-87.6 5.5,-81.6 5.5,-81.6 5.5,-42 5.5,-42 5.5,-36 11.5,-30 17.5,-30 17.5,-30 83.35,-30 83.35,-30 89.35,-30 95.35,-36 95.35,-42 95.35,-42 95.35,-81.6 95.35,-81.6 95.35,-87.6 89.35,-93.6 83.35,-93.6\"/>\n",
       "<text text-anchor=\"start\" x=\"25.68\" y=\"-70.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">mpg_df</text>\n",
       "<text text-anchor=\"start\" x=\"16.3\" y=\"-42.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">DataFrame</text>\n",
       "</g>\n",
       "<!-- data_sets -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>data_sets</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M196.2,-93.6C196.2,-93.6 136.35,-93.6 136.35,-93.6 130.35,-93.6 124.35,-87.6 124.35,-81.6 124.35,-81.6 124.35,-42 124.35,-42 124.35,-36 130.35,-30 136.35,-30 136.35,-30 196.2,-30 196.2,-30 202.2,-30 208.2,-36 208.2,-42 208.2,-42 208.2,-81.6 208.2,-81.6 208.2,-87.6 202.2,-93.6 196.2,-93.6\"/>\n",
       "<text text-anchor=\"start\" x=\"135.15\" y=\"-70.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">data_sets</text>\n",
       "<text text-anchor=\"start\" x=\"155.78\" y=\"-42.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n",
       "</g>\n",
       "<!-- mpg_df&#45;&gt;data_sets -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>mpg_df&#45;&gt;data_sets</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M95.77,-61.8C101.33,-61.8 107.06,-61.8 112.73,-61.8\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"112.7,-65.3 122.7,-61.8 112.7,-58.3 112.7,-65.3\"/>\n",
       "</g>\n",
       "<!-- evaluated_model -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>evaluated_model</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M491.65,-93.6C491.65,-93.6 383.8,-93.6 383.8,-93.6 377.8,-93.6 371.8,-87.6 371.8,-81.6 371.8,-81.6 371.8,-42 371.8,-42 371.8,-36 377.8,-30 383.8,-30 383.8,-30 491.65,-30 491.65,-30 497.65,-30 503.65,-36 503.65,-42 503.65,-42 503.65,-81.6 503.65,-81.6 503.65,-87.6 497.65,-93.6 491.65,-93.6\"/>\n",
       "<text text-anchor=\"start\" x=\"382.6\" y=\"-70.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">evaluated_model</text>\n",
       "<text text-anchor=\"start\" x=\"427.23\" y=\"-42.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n",
       "</g>\n",
       "<!-- data_sets&#45;&gt;evaluated_model -->\n",
       "<g id=\"edge3\" class=\"edge\">\n",
       "<title>data_sets&#45;&gt;evaluated_model</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M208.56,-69.39C217.94,-70.81 227.89,-72.07 237.2,-72.8 283.99,-76.46 295.95,-75.53 342.8,-72.8 348.44,-72.47 354.27,-72.03 360.12,-71.52\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"360.25,-75.02 369.87,-70.59 359.58,-68.05 360.25,-75.02\"/>\n",
       "</g>\n",
       "<!-- linear_model -->\n",
       "<g id=\"node4\" class=\"node\">\n",
       "<title>linear_model</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M330.8,-63.6C330.8,-63.6 249.2,-63.6 249.2,-63.6 243.2,-63.6 237.2,-57.6 237.2,-51.6 237.2,-51.6 237.2,-12 237.2,-12 237.2,-6 243.2,0 249.2,0 249.2,0 330.8,0 330.8,0 336.8,0 342.8,-6 342.8,-12 342.8,-12 342.8,-51.6 342.8,-51.6 342.8,-57.6 336.8,-63.6 330.8,-63.6\"/>\n",
       "<text text-anchor=\"start\" x=\"248\" y=\"-40.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">linear_model</text>\n",
       "<text text-anchor=\"start\" x=\"279.5\" y=\"-12.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n",
       "</g>\n",
       "<!-- data_sets&#45;&gt;linear_model -->\n",
       "<g id=\"edge4\" class=\"edge\">\n",
       "<title>data_sets&#45;&gt;linear_model</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M208.49,-51.64C214.03,-50.28 219.83,-48.85 225.65,-47.41\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"226.48,-50.81 235.36,-45.02 224.81,-44.02 226.48,-50.81\"/>\n",
       "</g>\n",
       "<!-- linear_model&#45;&gt;evaluated_model -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>linear_model&#45;&gt;evaluated_model</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M343.2,-42.55C348.77,-43.69 354.52,-44.88 360.3,-46.07\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"359.43,-49.46 369.93,-48.05 360.84,-42.61 359.43,-49.46\"/>\n",
       "</g>\n",
       "<!-- function -->\n",
       "<g id=\"node5\" class=\"node\">\n",
       "<title>function</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M72.85,-148.1C72.85,-148.1 28,-148.1 28,-148.1 22,-148.1 16,-142.1 16,-136.1 16,-136.1 16,-123.5 16,-123.5 16,-117.5 22,-111.5 28,-111.5 28,-111.5 72.85,-111.5 72.85,-111.5 78.85,-111.5 84.85,-117.5 84.85,-123.5 84.85,-123.5 84.85,-136.1 84.85,-136.1 84.85,-142.1 78.85,-148.1 72.85,-148.1\"/>\n",
       "<text text-anchor=\"middle\" x=\"50.43\" y=\"-124\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">function</text>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<hamilton.driver.Driver at 0x15043e8b0>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build a Hamilton driver from the `pipeline` module produced by the %%cell_to_module cell\n",
    "dr = (\n",
    "    driver.Builder()\n",
    "    .with_modules(pipeline)\n",
    "    .build()\n",
    ")\n",
    "dr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-07-20T17:37:06.029923Z",
     "start_time": "2024-07-20T17:37:05.934165Z"
    },
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "923eff9a-ce20-484e-a4c7-acfbebb58e16",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'evaluated_model': {'linear_model': 2.4926580150007007},\n",
       " 'linear_model': {'linear_model': LinearRegression(),\n",
       "  'scaler': StandardScaler()}}"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Run the full pipeline, keeping the fitted model so later cells can pass it back in as an override\n",
    "result = dr.execute(\n",
    "    [\"evaluated_model\", \"linear_model\"],\n",
    ")\n",
    "result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-07-20T17:37:36.535977Z",
     "start_time": "2024-07-20T17:37:36.241822Z"
    },
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "399b0164-819e-4c8a-a46d-021742ace28e",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [
    {
     "data": {
      "image/svg+xml": [
       "<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
       "<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
       " \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
       "<!-- Generated by graphviz version 10.0.1 (20240210.2158)\n",
       " -->\n",
       "<!-- Pages: 1 -->\n",
       "<svg width=\"512pt\" height=\"305pt\"\n",
       " viewBox=\"0.00 0.00 512.40 304.80\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
       "<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 300.8)\">\n",
       "<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-300.8 508.4,-300.8 508.4,4 -4,4\"/>\n",
       "<g id=\"clust1\" class=\"cluster\">\n",
       "<title>cluster__legend</title>\n",
       "<polygon fill=\"#ffffff\" stroke=\"black\" points=\"8,-103.8 8,-288.8 94.35,-288.8 94.35,-103.8 8,-103.8\"/>\n",
       "<text text-anchor=\"middle\" x=\"51.18\" y=\"-271.5\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">Legend</text>\n",
       "</g>\n",
       "<!-- mpg_df -->\n",
       "<g id=\"node1\" class=\"node\">\n",
       "<title>mpg_df</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M84.1,-93.6C84.1,-93.6 18.25,-93.6 18.25,-93.6 12.25,-93.6 6.25,-87.6 6.25,-81.6 6.25,-81.6 6.25,-42 6.25,-42 6.25,-36 12.25,-30 18.25,-30 18.25,-30 84.1,-30 84.1,-30 90.1,-30 96.1,-36 96.1,-42 96.1,-42 96.1,-81.6 96.1,-81.6 96.1,-87.6 90.1,-93.6 84.1,-93.6\"/>\n",
       "<text text-anchor=\"start\" x=\"26.43\" y=\"-70.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">mpg_df</text>\n",
       "<text text-anchor=\"start\" x=\"17.05\" y=\"-42.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">DataFrame</text>\n",
       "</g>\n",
       "<!-- data_sets -->\n",
       "<g id=\"node2\" class=\"node\">\n",
       "<title>data_sets</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M196.95,-93.6C196.95,-93.6 137.1,-93.6 137.1,-93.6 131.1,-93.6 125.1,-87.6 125.1,-81.6 125.1,-81.6 125.1,-42 125.1,-42 125.1,-36 131.1,-30 137.1,-30 137.1,-30 196.95,-30 196.95,-30 202.95,-30 208.95,-36 208.95,-42 208.95,-42 208.95,-81.6 208.95,-81.6 208.95,-87.6 202.95,-93.6 196.95,-93.6\"/>\n",
       "<text text-anchor=\"start\" x=\"135.9\" y=\"-70.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">data_sets</text>\n",
       "<text text-anchor=\"start\" x=\"156.53\" y=\"-42.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n",
       "</g>\n",
       "<!-- mpg_df&#45;&gt;data_sets -->\n",
       "<g id=\"edge1\" class=\"edge\">\n",
       "<title>mpg_df&#45;&gt;data_sets</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M96.52,-61.8C102.08,-61.8 107.81,-61.8 113.48,-61.8\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"113.45,-65.3 123.45,-61.8 113.45,-58.3 113.45,-65.3\"/>\n",
       "</g>\n",
       "<!-- evaluated_model -->\n",
       "<g id=\"node3\" class=\"node\">\n",
       "<title>evaluated_model</title>\n",
       "<path fill=\"#ffc857\" stroke=\"black\" d=\"M492.4,-93.6C492.4,-93.6 384.55,-93.6 384.55,-93.6 378.55,-93.6 372.55,-87.6 372.55,-81.6 372.55,-81.6 372.55,-42 372.55,-42 372.55,-36 378.55,-30 384.55,-30 384.55,-30 492.4,-30 492.4,-30 498.4,-30 504.4,-36 504.4,-42 504.4,-42 504.4,-81.6 504.4,-81.6 504.4,-87.6 498.4,-93.6 492.4,-93.6\"/>\n",
       "<text text-anchor=\"start\" x=\"383.35\" y=\"-70.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">evaluated_model</text>\n",
       "<text text-anchor=\"start\" x=\"427.98\" y=\"-42.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n",
       "</g>\n",
       "<!-- data_sets&#45;&gt;evaluated_model -->\n",
       "<g id=\"edge3\" class=\"edge\">\n",
       "<title>data_sets&#45;&gt;evaluated_model</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M209.31,-69.39C218.69,-70.81 228.64,-72.07 237.95,-72.8 284.74,-76.46 296.7,-75.53 343.55,-72.8 349.19,-72.47 355.02,-72.03 360.87,-71.52\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"361,-75.02 370.62,-70.59 360.33,-68.05 361,-75.02\"/>\n",
       "</g>\n",
       "<!-- linear_model -->\n",
       "<g id=\"node4\" class=\"node\">\n",
       "<title>linear_model</title>\n",
       "<polygon fill=\"#b4d8e4\" stroke=\"black\" points=\"343.55,-63.6 237.95,-63.6 237.95,0 343.55,0 343.55,-63.6\"/>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"249.95,-63.6 237.95,-51.6\"/>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"237.95,-12 249.95,0\"/>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"331.55,0 343.55,-12\"/>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"343.55,-51.6 331.55,-63.6\"/>\n",
       "<text text-anchor=\"start\" x=\"248.75\" y=\"-40.5\" font-family=\"Helvetica,sans-Serif\" font-weight=\"bold\" font-size=\"14.00\">linear_model</text>\n",
       "<text text-anchor=\"start\" x=\"280.25\" y=\"-12.5\" font-family=\"Helvetica,sans-Serif\" font-style=\"italic\" font-size=\"14.00\">dict</text>\n",
       "</g>\n",
       "<!-- data_sets&#45;&gt;linear_model -->\n",
       "<g id=\"edge4\" class=\"edge\">\n",
       "<title>data_sets&#45;&gt;linear_model</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M209.24,-51.64C214.78,-50.28 220.58,-48.85 226.4,-47.41\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"227.23,-50.81 236.11,-45.02 225.56,-44.02 227.23,-50.81\"/>\n",
       "</g>\n",
       "<!-- linear_model&#45;&gt;evaluated_model -->\n",
       "<g id=\"edge2\" class=\"edge\">\n",
       "<title>linear_model&#45;&gt;evaluated_model</title>\n",
       "<path fill=\"none\" stroke=\"black\" d=\"M343.95,-42.55C349.52,-43.69 355.27,-44.88 361.05,-46.07\"/>\n",
       "<polygon fill=\"black\" stroke=\"black\" points=\"360.18,-49.46 370.68,-48.05 361.59,-42.61 360.18,-49.46\"/>\n",
       "</g>\n",
       "<!-- function -->\n",
       "<g id=\"node5\" class=\"node\">\n",
       "<title>function</title>\n",
       "<path fill=\"#b4d8e4\" stroke=\"black\" d=\"M73.6,-148.1C73.6,-148.1 28.75,-148.1 28.75,-148.1 22.75,-148.1 16.75,-142.1 16.75,-136.1 16.75,-136.1 16.75,-123.5 16.75,-123.5 16.75,-117.5 22.75,-111.5 28.75,-111.5 28.75,-111.5 73.6,-111.5 73.6,-111.5 79.6,-111.5 85.6,-117.5 85.6,-123.5 85.6,-123.5 85.6,-136.1 85.6,-136.1 85.6,-142.1 79.6,-148.1 73.6,-148.1\"/>\n",
       "<text text-anchor=\"middle\" x=\"51.18\" y=\"-124\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">function</text>\n",
       "</g>\n",
       "<!-- output -->\n",
       "<g id=\"node6\" class=\"node\">\n",
       "<title>output</title>\n",
       "<path fill=\"#ffc857\" stroke=\"black\" d=\"M68.73,-203.1C68.73,-203.1 33.63,-203.1 33.63,-203.1 27.63,-203.1 21.63,-197.1 21.63,-191.1 21.63,-191.1 21.63,-178.5 21.63,-178.5 21.63,-172.5 27.63,-166.5 33.63,-166.5 33.63,-166.5 68.73,-166.5 68.73,-166.5 74.73,-166.5 80.73,-172.5 80.73,-178.5 80.73,-178.5 80.73,-191.1 80.73,-191.1 80.73,-197.1 74.73,-203.1 68.73,-203.1\"/>\n",
       "<text text-anchor=\"middle\" x=\"51.18\" y=\"-179\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">output</text>\n",
       "</g>\n",
       "<!-- override -->\n",
       "<g id=\"node7\" class=\"node\">\n",
       "<title>override</title>\n",
       "<polygon fill=\"#b4d8e4\" stroke=\"black\" points=\"86.35,-258.1 16,-258.1 16,-221.5 86.35,-221.5 86.35,-258.1\"/>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"28,-258.1 16,-246.1\"/>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"16,-233.5 28,-221.5\"/>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"74.35,-221.5 86.35,-233.5\"/>\n",
       "<polyline fill=\"none\" stroke=\"black\" points=\"86.35,-246.1 74.35,-258.1\"/>\n",
       "<text text-anchor=\"middle\" x=\"51.18\" y=\"-234\" font-family=\"Helvetica,sans-Serif\" font-size=\"14.00\">override</text>\n",
       "</g>\n",
       "</g>\n",
       "</svg>\n"
      ],
      "text/plain": [
       "<graphviz.graphs.Digraph at 0x15043edc0>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Visualize the execution when `linear_model` is supplied via overrides instead of being recomputed\n",
    "dr.visualize_execution(\n",
    "    [\"evaluated_model\"],\n",
    "    overrides={\"linear_model\": result[\"linear_model\"]},\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-07-20T17:37:37.088527Z",
     "start_time": "2024-07-20T17:37:37.025195Z"
    },
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {},
     "inputWidgets": {},
     "nuid": "b617038c-2ffa-498c-b6c9-bbd6bb79bb1f",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'evaluated_model': {'linear_model': 2.4926580150007007}}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Execute with overrides: reuse the already-fitted model rather than refitting it\n",
    "dr.execute(\n",
    "    [\"evaluated_model\"],\n",
    "    overrides={\"linear_model\": result[\"linear_model\"]},\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "application/vnd.databricks.v1+notebook": {
   "dashboards": [],
   "language": "python",
   "notebookMetadata": {
    "mostRecentlyExecutedCommandWithImplicitDF": {
     "commandId": 2746022128672016,
     "dataframes": [
      "_sqldf"
     ]
    },
    "pythonIndentUnit": 4
   },
   "notebookName": "MPG Simple V1 Target",
   "widgets": {}
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
